# Distributed training of a BERT model with NVIDIA NeMo
#
# Finetunes a BERT-like model on the GLUE CoLA task. Uses the NeMo toolkit
# to train across multiple nodes, each node having a V100 GPU.
#
# Uses glue_benchmark.py script from the NeMo examples:
# https://github.com/NVIDIA/NeMo/blob/2ce45369f7ab6cd20c376d1ed393160f5e54be0c/examples/nlp/glue_benchmark/glue_benchmark.py
#
# Usage:
#   sky launch -c nemo_bert nemo_bert.yaml
#
#   # Or try on spot A100 GPUs:
#   sky launch -c nemo_bert nemo_bert.yaml --use-spot --gpus A100:1
#
#   # Terminate cluster after you're done
#   sky down nemo_bert

# Hardware requested for each node: one V100 GPU.
resources:
  accelerators: V100:1

# Number of nodes the task is launched on.
num_nodes: 2

setup: |
  # Marker written only after every install step below has succeeded, so an
  # interrupted or failed setup is retried on the next launch instead of
  # being mistaken for a finished one just because the conda env exists.
  setup_marker="$HOME/.nemo_bert_setup_done"

  # Reuse the conda env if it already exists; otherwise create it.
  if conda activate nemo; then
      echo "conda env exists"
  else
      conda create -y --name nemo python==3.10.12 || exit 1
      conda activate nemo || exit 1
  fi

  if [ ! -f "$setup_marker" ]; then
      # Stop at the first failing command so a broken install surfaces as a
      # setup failure rather than as a missing module or file at run time.
      set -e

      # Install PyTorch
      pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

      # Install nemo
      sudo apt-get update
      sudo apt-get install -y libsndfile1 ffmpeg
      pip install Cython
      pip install "nemo_toolkit[all]"

      # Clone the NeMo repo to get the examples. Skipped when a previous
      # attempt already cloned it, because git refuses a non-empty target.
      if [ ! -d NeMo/.git ]; then
          git clone https://github.com/NVIDIA/NeMo.git
      fi

      # Download GLUE dataset. -O overwrites a stale copy of the script
      # instead of saving the new one under a ".1" suffix.
      wget -O download_glue_data.py https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/70e86a10fbf4ab4ec3f04c9ba82ba58f87c530bf/download_glue_data.py
      python download_glue_data.py --data_dir glue_data --tasks CoLA

      touch "$setup_marker"
      set +e
  fi

run: |
  conda activate nemo

  # SKYPILOT_NODE_IPS is read as a newline-separated list: its line count is
  # used as the number of nodes and its first line as the master address.
  node_count=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Options for the torch.distributed.run launcher
  launcher_opts=(
    --nproc_per_node=${SKYPILOT_NUM_GPUS_PER_NODE}
    --nnodes=${node_count}
    --node_rank=${SKYPILOT_NODE_RANK}
    --master_addr=${head_ip}
    --master_port=8008
  )

  # Config overrides passed to glue_benchmark.py
  benchmark_overrides=(
    model.dataset.data_dir=glue_data/CoLA
    model.task_name=cola
    trainer.max_epochs=10
    trainer.num_nodes=${node_count}
  )

  # Run glue_benchmark.py
  python -m torch.distributed.run \
    "${launcher_opts[@]}" \
    NeMo/examples/nlp/glue_benchmark/glue_benchmark.py \
    "${benchmark_overrides[@]}"

